from LeetTool import TreeNode


class Solution:
    """Binary Tree Tilt: sum of |left-subtree sum - right-subtree sum| over all nodes."""

    def __init__(self):
        # Running total of node tilts, accumulated by ``dfs``.
        self.ans = 0

    def findTilt(self, root: "TreeNode") -> int:
        """Return the sum of every node's tilt in the tree rooted at ``root``.

        A node's tilt is the absolute difference between the sum of values in
        its left subtree and the sum of values in its right subtree; an empty
        subtree sums to 0. An empty tree (``root`` falsy/None) has tilt 0.
        """
        # Reset first so repeated calls on the same instance don't accumulate.
        self.ans = 0
        self.dfs(root)
        return self.ans

    def dfs(self, node) -> int:
        """Return the sum of node values in the subtree rooted at ``node``.

        Side effect: adds the tilt of every node in that subtree to
        ``self.ans``. Uses an explicit stack instead of recursion so deep
        (e.g. fully skewed) trees cannot exceed Python's recursion limit.
        """
        if not node:
            return 0
        # Pre-order walk: a parent is always recorded before its children, so
        # iterating the list in reverse processes children first (post-order).
        order = []
        stack = [node]
        while stack:
            cur = stack.pop()
            order.append(cur)
            if cur.left:
                stack.append(cur.left)
            if cur.right:
                stack.append(cur.right)
        # Subtree sums keyed by id(): every node stays alive in ``order`` so the
        # ids are unique, and this does not require TreeNode to be hashable.
        subtree_sum = {}
        for cur in reversed(order):
            sum_left = subtree_sum[id(cur.left)] if cur.left else 0
            sum_right = subtree_sum[id(cur.right)] if cur.right else 0
            self.ans += abs(sum_left - sum_right)
            subtree_sum[id(cur)] = sum_left + sum_right + cur.val
        return subtree_sum[id(node)]


if __name__ == "__main__":
    # Placeholder script entry point: emits a single empty line.
    print("")
